#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project ：PythonProjects
@File    ：bp_chinese.py
@IDE     ：PyCharm
@Author  ：pipibao
@Date    ：2021/7/5 9:20 PM
'''
# Imports
import os
import time

import numpy as np

# Suppress TF INFO/WARNING logs; this must be set BEFORE importing tensorflow to take effect
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
import cv2 as cv

# Basic parameters
SIZE = 1024  # 32 * 32 pixels per flattened sample
WIDTH = 32
HEIGHT = 32
NUM_CLASSES = 31  # 31 province abbreviations in total
iterations = 1000  # maximum number of training iterations

# Directory where the trained model is saved
SAVER_DIR = "train_saver/chinese/"  # adjust to your own path
PROVINCES = (
    "川", "鄂", "赣", "甘", "贵", "桂", "黑", "沪", "冀", "津", "京", "吉", "辽", "鲁", "蒙", "闽",
    "宁", "青", "琼", "陕", "苏", "晋", "皖", "湘", "新", "豫", "渝", "粤", "云", "藏", "浙")
time_begin = time.time()

# Input placeholders: flattened image pixels and one-hot labels
x = tf.placeholder(tf.float32, shape=[None, SIZE])  # None is the (variable) batch size
y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASSES])  # one-hot label placeholder
x_image = tf.reshape(x, [-1, WIDTH, HEIGHT, 1])  # reshape to a 4-D NHWC tensor
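# Shape sketch: a batch of k samples is fed through x as a (k, 1024) array and
# reshaped to (k, 32, 32, 1). HEIGHT and WIDTH are equal here, so the row-major
# flatten used when loading the images maps back onto the same pixel grid.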


# Convolution + max-pooling layer helper
def conv_layer(inputs, W, b, conv_strides, kernel_size, pool_strides, padding):
    L1_conv = tf.nn.conv2d(inputs, W, strides=conv_strides, padding=padding)  # convolution
    L1_relu = tf.nn.relu(L1_conv + b)  # ReLU activation
    return tf.nn.max_pool(L1_relu, ksize=kernel_size, strides=pool_strides, padding='SAME')
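

# Shape trace for the network built in __main__ (32x32 inputs, 'SAME' padding):
#   input        (?, 32, 32,  1)
#   conv1 + pool (?, 16, 16, 12)   # 5x5 conv, stride 1; 2x2 max-pool, stride 2
#   conv2 + pool (?,  8,  8, 24)
# which is why the first fully connected weight matrix below is [8 * 8 * 24, 512].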


# Fully connected layer helper (with ReLU activation)
def full_connect(inputs, W, b):
    return tf.nn.relu(tf.matmul(inputs, W) + b)


# Arithmetic mean of a sequence (used to average per-batch losses)
def average(seq):
    return float(sum(seq)) / len(seq)


# Train the model
if __name__ == "__main__":
    # First pass over the image directories: count the training images
    input_count = 0
    for i in range(0, NUM_CLASSES):
        dir_name = os.path.join('train', str(i))  # adjust to your own data path
        for root, dirs, files in os.walk(dir_name):
            input_count += len(files)
    # Zero-initialized arrays: input_count rows of SIZE pixels / NUM_CLASSES one-hot labels
    input_images = np.zeros((input_count, SIZE), dtype=np.float32)
    input_labels = np.zeros((input_count, NUM_CLASSES), dtype=np.float32)
    # Second pass: build the image data and the one-hot labels
    index = 0
    for i in range(0, NUM_CLASSES):
        dir_name = os.path.join('train', str(i))
        for root, dirs, files in os.walk(dir_name):
            for filename in files:
                filepath = os.path.join(root, filename)
                img = cv.imread(filepath, 0)  # read as grayscale
                print(filepath)
                # cv.imshow('threshold', img)
                # cv.waitKey(0)
                # Binarize: pixels brighter than 150 become 1, the rest 0. Assumes
                # each sample is a 32x32 (HEIGHT x WIDTH) grayscale image; the
                # row-major flatten matches the w + h*width layout fed through x.
                input_images[index] = (img.flatten() > 150).astype(np.float32)
                input_labels[index][i] = 1
                index = index + 1
    # First pass over the validation directories: count the images.
    # NOTE: as written, this re-reads the training directory, so the accuracy
    # printed during training measures fit to the training data; point it at a
    # held-out set for a true validation score.
    val_count = 0
    for i in range(0, NUM_CLASSES):
        dir_name = os.path.join('train', str(i))  # adjust to your own data path
        for root, dirs, files in os.walk(dir_name):
            val_count += len(files)
    # Zero-initialized arrays: val_count rows of SIZE pixels / NUM_CLASSES one-hot labels
    val_images = np.zeros((val_count, SIZE), dtype=np.float32)
    val_labels = np.zeros((val_count, NUM_CLASSES), dtype=np.float32)
    # Second pass: build the validation data and labels (same binarization as above)
    index = 0
    for i in range(0, NUM_CLASSES):
        dir_name = os.path.join('train', str(i))
        for root, dirs, files in os.walk(dir_name):
            for filename in files:
                filepath = os.path.join(root, filename)
                img = cv.imread(filepath, 0)
                val_images[index] = (img.flatten() > 150).astype(np.float32)
                val_labels[index][i] = 1
                index = index + 1
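    # Optional sketch (not in the original script): the batching loop below feeds
    # samples in directory order; shuffling the training set once up front usually
    # helps convergence. Uncomment to use; it only assumes the arrays built above.
    # perm = np.random.permutation(input_count)
    # input_images, input_labels = input_images[perm], input_labels[perm]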
    with tf.Session() as sess:
        # First convolutional layer
        W_conv1 = tf.Variable(tf.truncated_normal([5, 5, 1, 12], stddev=0.1), name="W_conv1")
        b_conv1 = tf.Variable(tf.constant(0.1, shape=[12]), name="b_conv1")  # bias, initialized to 0.1
        conv_strides = [1, 1, 1, 1]  # convolution stride 1 in both height and width
        kernel_size = [1, 2, 2, 1]  # 2x2 max-pooling window
        pool_strides = [1, 2, 2, 1]  # pooling stride 2 in both height and width
        L1_pool = conv_layer(x_image, W_conv1, b_conv1, conv_strides, kernel_size, pool_strides,
                             padding='SAME')  # output of the first conv + pool block; x_image is the input

        # Second convolutional layer
        W_conv2 = tf.Variable(tf.truncated_normal([5, 5, 12, 24], stddev=0.1), name="W_conv2")
        b_conv2 = tf.Variable(tf.constant(0.1, shape=[24]), name="b_conv2")
        conv_strides = [1, 1, 1, 1]
        kernel_size = [1, 2, 2, 1]
        pool_strides = [1, 2, 2, 1]
        L2_pool = conv_layer(L1_pool, W_conv2, b_conv2, conv_strides, kernel_size, pool_strides, padding="SAME")

        # Fully connected layer
        W_fc1 = tf.Variable(tf.truncated_normal([8 * 8 * 24, 512], stddev=0.1), name="W_fc1")
        b_fc1 = tf.Variable(tf.constant(0.1, shape=[512]), name="b_fc1")
        h_pool2_flat = tf.reshape(L2_pool, [-1, 8 * 8 * 24])  # flatten the 8x8x24 feature maps to one vector per sample
        h_fc1 = full_connect(h_pool2_flat, W_fc1, b_fc1)  # fully connected + ReLU

        # dropout
        keep_prob = tf.placeholder(tf.float32)
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
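        # keep_prob is fed 0.5 during training (below) and 1.0 for evaluation,
        # so dropout is active only while training.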

        # Readout layer
        W_fc2 = tf.Variable(tf.truncated_normal([512, NUM_CLASSES], stddev=0.1), name="W_fc2")
        b_fc2 = tf.Variable(tf.constant(0.1, shape=[NUM_CLASSES]), name="b_fc2")

        # Loss, optimizer, and evaluation ops
        y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2  # final output layer: a plain matmul plus bias produces the logits
        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))  # cross-entropy loss
        train_step = tf.train.AdamOptimizer(1e-5).minimize(cross_entropy)
        correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Initialize the saver and all variables
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        time_elapsed = time.time() - time_begin
        print("Time spent reading image files: %d s" % time_elapsed)
        time_begin = time.time()
        print("Read %s training images and %s labels in total" % (input_count, input_count))

        # Batch setup. To support an arbitrary number of images, a remainder batch is
        # used: e.g. if the batch size were 60 and there were 150 images, two batches
        # of 60 would be fed, then a final batch of 30 for the remainder.
        batch_size = 64  # images per training step
        batches_count = int(input_count / batch_size)
        remainder = input_count % batch_size
        print("Training data split into %s batches: %s samples per full batch, %s in the last partial batch"
              % (batches_count + (1 if remainder > 0 else 0), batch_size, remainder))

        # Training loop
        for it in range(iterations):
            sum_loss = []
            for n in range(batches_count):
                # feed_dict supplies one batch: x gets the binarized images, y_ the one-hot labels
                loss, _ = sess.run([cross_entropy, train_step],
                                   feed_dict={x: input_images[n * batch_size:(n + 1) * batch_size],
                                              y_: input_labels[n * batch_size:(n + 1) * batch_size],
                                              keep_prob: 0.5})
                sum_loss.append(loss)
            if remainder > 0:
                # final partial batch covers the remaining samples
                start_index = batches_count * batch_size
                loss, _ = sess.run([cross_entropy, train_step],
                                   feed_dict={x: input_images[start_index:input_count],
                                              y_: input_labels[start_index:input_count], keep_prob: 0.5})
                sum_loss.append(loss)
            avg_loss = average(sum_loss)

            # Every 5 iterations, evaluate and stop once accuracy reaches 99%
            if it % 5 == 0:
                loss1, iterate_accuracy = sess.run([cross_entropy, accuracy],
                                                   feed_dict={x: val_images, y_: val_labels, keep_prob: 1.0})
                print('Iteration %d: accuracy %0.5f%%    avg training loss: %s    validation loss: %s'
                      % (it, iterate_accuracy * 100, avg_loss, loss1))
                if iterate_accuracy >= 0.99:
                    break

        # Training finished; report elapsed time
        print('Training complete')
        time_elapsed = time.time() - time_begin
        print("Time spent training: %d s" % time_elapsed)

        # Save the trained model
        if not os.path.exists(SAVER_DIR):
            print('Model save directory does not exist; creating it')
            os.makedirs(SAVER_DIR)
        saver_path = saver.save(sess, "%smodel.ckpt" % SAVER_DIR)
        print("Model saved to:", saver_path)
